{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jv-tHPvR-JKa"
      },
      "source": [
        "# Graph Transformer in a Nutshell\n",
        "\n",
        "The **Transformer** [(Vaswani et al. 2017)](https://proceedings.neurips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html) has been proven an effective learning architecture in natural language processing and computer vision.\n",
        "Recently, researchers turns to explore the application of transformer in graph learning. They have achieved inital success on many practical tasks, e.g., graph property prediction.\n",
        "[Dwivedi et al. (2020)](https://arxiv.org/abs/2012.09699) firstly generalize the transformer neural architecture to graph-structured data. Here, we present how to build such a graph transformer with DGL's sparse matrix APIs.\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/graph_transformer.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/sparse/graph_transformer.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Install required packages.\n",
        "import os\n",
        "import torch\n",
        "os.environ['TORCH'] = torch.__version__\n",
        "os.environ['DGLBACKEND'] = \"pytorch\"\n",
        "\n",
        "# Uncomment below to install required packages. If the CUDA version is not 11.8,\n",
        "# check the https://www.dgl.ai/pages/start.html to find the supported CUDA\n",
        "# version and corresponding command to install DGL.\n",
        "#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null\n",
        "#!pip install ogb >/dev/null\n",
        "\n",
        "try:\n",
        "    import dgl\n",
        "    installed = True\n",
        "except ImportError:\n",
        "    installed = False\n",
        "print(\"DGL installed!\" if installed else \"Failed to install DGL!\")"
      ],
      "metadata": {
        "id": "8wIJZQqODy-7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nOpFdtLI-JKb"
      },
      "source": [
        "## Sparse Multi-head Attention\n",
        "\n",
        "Recall the all-pairs scaled-dot-product attention mechanism in vanillar Transformer:\n",
        "\n",
        "$$\\text{Attn}=\\text{softmax}(\\dfrac{QK^T} {\\sqrt{d}})V,$$\n",
        "\n",
        "The graph transformer (GT) model employs a Sparse Multi-head Attention block:\n",
        "\n",
        "$$\\text{SparseAttn}(Q, K, V, A) = \\text{softmax}(\\frac{(QK^T) \\circ A}{\\sqrt{d}})V,$$\n",
        "\n",
        "where $Q, K, V ∈\\mathbb{R}^{N\\times d}$ are query feature, key feature, and value feature, respectively. $A\\in[0,1]^{N\\times N}$ is the adjacency matrix of the input graph. $(QK^T)\\circ A$ means that the multiplication of query matrix and key matrix is followed by a Hadamard product (or element-wise multiplication) with the sparse adjacency matrix as illustrated in the figure below:\n",
        "\n",
        "<img src=\"https://drive.google.com/uc?id=1OgMAewLR3Z1vz5y4J8aPRSeaU3g8iQfX\" width=\"500\">\n",
        "\n",
        "Essentially, only the attention scores between connected nodes are computed according to the sparsity of $A$. This operation is also called *Sampled Dense Dense Matrix Multiplication (SDDMM)*.\n",
        "\n",
        "Enjoying the [batched SDDMM API](https://docs.dgl.ai/en/latest/generated/dgl.sparse.bsddmm.html) in DGL, we can parallel the computation on multiple attention heads (different representation subspaces).\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dh7zc5v0-JKb"
      },
      "outputs": [],
      "source": [
        "import dgl\n",
        "import dgl.nn as dglnn\n",
        "import dgl.sparse as dglsp\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "\n",
        "from dgl.data import AsGraphPredDataset\n",
        "from dgl.dataloading import GraphDataLoader\n",
        "from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator\n",
        "from ogb.graphproppred.mol_encoder import AtomEncoder\n",
        "from tqdm import tqdm\n",
        "\n",
        "\n",
        "class SparseMHA(nn.Module):\n",
        "    \"\"\"Sparse Multi-head Attention Module\"\"\"\n",
        "\n",
        "    def __init__(self, hidden_size=80, num_heads=8):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "        self.num_heads = num_heads\n",
        "        self.head_dim = hidden_size // num_heads\n",
        "        self.scaling = self.head_dim**-0.5\n",
        "\n",
        "        self.q_proj = nn.Linear(hidden_size, hidden_size)\n",
        "        self.k_proj = nn.Linear(hidden_size, hidden_size)\n",
        "        self.v_proj = nn.Linear(hidden_size, hidden_size)\n",
        "        self.out_proj = nn.Linear(hidden_size, hidden_size)\n",
        "\n",
        "    def forward(self, A, h):\n",
        "        N = len(h)\n",
        "        # [N, dh, nh]\n",
        "        q = self.q_proj(h).reshape(N, self.head_dim, self.num_heads)\n",
        "        q *= self.scaling\n",
        "        # [N, dh, nh]\n",
        "        k = self.k_proj(h).reshape(N, self.head_dim, self.num_heads)\n",
        "        # [N, dh, nh]\n",
        "        v = self.v_proj(h).reshape(N, self.head_dim, self.num_heads)\n",
        "\n",
        "        ######################################################################\n",
        "        # (HIGHLIGHT) Compute the multi-head attention with Sparse Matrix API\n",
        "        ######################################################################\n",
        "        attn = dglsp.bsddmm(A, q, k.transpose(1, 0))  # (sparse) [N, N, nh]\n",
        "        # Sparse softmax by default applies on the last sparse dimension.\n",
        "        attn = attn.softmax()  # (sparse) [N, N, nh]\n",
        "        out = dglsp.bspmm(attn, v)  # [N, dh, nh]\n",
        "\n",
        "        return self.out_proj(out.reshape(N, -1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3_Fm6Lrx-JKc"
      },
      "source": [
        "## Graph Transformer Layer\n",
        "\n",
        "The GT layer is composed of Multi-head Attention, Batch Norm, and Feed-forward Network, connected by residual links as in vanilla transformer.\n",
        "\n",
        "<img src=\"https://drive.google.com/uc?id=1cm-Ijw7bUQIOkoTKn5MQ3m4-66JqCsMz\" width=\"300\">"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M6h7JVWT-JKd"
      },
      "outputs": [],
      "source": [
        "class GTLayer(nn.Module):\n",
        "    \"\"\"Graph Transformer Layer\"\"\"\n",
        "\n",
        "    def __init__(self, hidden_size=80, num_heads=8):\n",
        "        super().__init__()\n",
        "        self.MHA = SparseMHA(hidden_size=hidden_size, num_heads=num_heads)\n",
        "        self.batchnorm1 = nn.BatchNorm1d(hidden_size)\n",
        "        self.batchnorm2 = nn.BatchNorm1d(hidden_size)\n",
        "        self.FFN1 = nn.Linear(hidden_size, hidden_size * 2)\n",
        "        self.FFN2 = nn.Linear(hidden_size * 2, hidden_size)\n",
        "\n",
        "    def forward(self, A, h):\n",
        "        h1 = h\n",
        "        h = self.MHA(A, h)\n",
        "        h = self.batchnorm1(h + h1)\n",
        "\n",
        "        h2 = h\n",
        "        h = self.FFN2(F.relu(self.FFN1(h)))\n",
        "        h = h2 + h\n",
        "\n",
        "        return self.batchnorm2(h)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t40DhVjI-JKd"
      },
      "source": [
        "## Graph Transformer Model\n",
        "\n",
        "The GT model is constructed by stacking GT layers. The input positional encoding of vanilla transformer is replaced with Laplacian positional encoding [(Dwivedi et al. 2020)](https://arxiv.org/abs/2003.00982). For the graph-level prediction task, an extra pooler is stacked on top of GT layers to aggregate node feature of the same graph."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UrjvEBrF-JKe"
      },
      "outputs": [],
      "source": [
        "class GTModel(nn.Module):\n",
        "    def __init__(\n",
        "        self,\n",
        "        out_size,\n",
        "        hidden_size=80,\n",
        "        pos_enc_size=2,\n",
        "        num_layers=8,\n",
        "        num_heads=8,\n",
        "    ):\n",
        "        super().__init__()\n",
        "        self.atom_encoder = AtomEncoder(hidden_size)\n",
        "        self.pos_linear = nn.Linear(pos_enc_size, hidden_size)\n",
        "        self.layers = nn.ModuleList(\n",
        "            [GTLayer(hidden_size, num_heads) for _ in range(num_layers)]\n",
        "        )\n",
        "        self.pooler = dglnn.SumPooling()\n",
        "        self.predictor = nn.Sequential(\n",
        "            nn.Linear(hidden_size, hidden_size // 2),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(hidden_size // 2, hidden_size // 4),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(hidden_size // 4, out_size),\n",
        "        )\n",
        "\n",
        "    def forward(self, g, X, pos_enc):\n",
        "        indices = torch.stack(g.edges())\n",
        "        N = g.num_nodes()\n",
        "        A = dglsp.spmatrix(indices, shape=(N, N))\n",
        "        h = self.atom_encoder(X) + self.pos_linear(pos_enc)\n",
        "        for layer in self.layers:\n",
        "            h = layer(A, h)\n",
        "        h = self.pooler(g, h)\n",
        "\n",
        "        return self.predictor(h)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RdrPU18I-JKe"
      },
      "source": [
        "## Training\n",
        "\n",
        "We train the GT model on [ogbg-molhiv](https://ogb.stanford.edu/docs/graphprop/#ogbg-mol) benchmark. The Laplacian positional encoding of each graph is pre-computed (with the API [here](https://docs.dgl.ai/en/latest/generated/dgl.laplacian_pe.html)) as part of the input to the model.\n",
        "\n",
        "*Note that we down-sample the dataset to make this demo runs faster. See the* [*example script*](https://github.com/dmlc/dgl/blob/master/examples/sparse/graph_transformer.py) *for the performance on the full dataset.*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V41i0w-9-JKe",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "15343d1a-a32d-4677-d053-d9da96910f43"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Computing Laplacian PE:   1%|          | 25/4000 [00:00<00:16, 244.77it/s]/usr/local/lib/python3.8/dist-packages/dgl/backend/pytorch/tensor.py:52: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:250.)\n",
            "  return th.as_tensor(data, dtype=dtype)\n",
            "Computing Laplacian PE: 100%|██████████| 4000/4000 [00:13<00:00, 296.04it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 000, Loss: 0.2486, Val: 0.3082, Test: 0.3068\n",
            "Epoch: 001, Loss: 0.1695, Val: 0.4684, Test: 0.4572\n",
            "Epoch: 002, Loss: 0.1428, Val: 0.5887, Test: 0.4721\n",
            "Epoch: 003, Loss: 0.1237, Val: 0.6375, Test: 0.5010\n",
            "Epoch: 004, Loss: 0.1127, Val: 0.6628, Test: 0.4854\n",
            "Epoch: 005, Loss: 0.1047, Val: 0.6811, Test: 0.4983\n",
            "Epoch: 006, Loss: 0.0949, Val: 0.6751, Test: 0.5409\n",
            "Epoch: 007, Loss: 0.0901, Val: 0.6340, Test: 0.5357\n",
            "Epoch: 008, Loss: 0.0811, Val: 0.6717, Test: 0.5543\n",
            "Epoch: 009, Loss: 0.0643, Val: 0.7861, Test: 0.5628\n",
            "Epoch: 010, Loss: 0.0489, Val: 0.7319, Test: 0.5341\n",
            "Epoch: 011, Loss: 0.0340, Val: 0.7884, Test: 0.5299\n",
            "Epoch: 012, Loss: 0.0285, Val: 0.5887, Test: 0.4293\n",
            "Epoch: 013, Loss: 0.0361, Val: 0.5514, Test: 0.3419\n",
            "Epoch: 014, Loss: 0.0451, Val: 0.6795, Test: 0.4964\n",
            "Epoch: 015, Loss: 0.0429, Val: 0.7405, Test: 0.5527\n",
            "Epoch: 016, Loss: 0.0331, Val: 0.7859, Test: 0.4994\n",
            "Epoch: 017, Loss: 0.0177, Val: 0.6544, Test: 0.4457\n",
            "Epoch: 018, Loss: 0.0201, Val: 0.8250, Test: 0.6073\n",
            "Epoch: 019, Loss: 0.0093, Val: 0.7356, Test: 0.5561\n"
          ]
        }
      ],
      "source": [
        "@torch.no_grad()\n",
        "def evaluate(model, dataloader, evaluator, device):\n",
        "    model.eval()\n",
        "    y_true = []\n",
        "    y_pred = []\n",
        "    for batched_g, labels in dataloader:\n",
        "        batched_g, labels = batched_g.to(device), labels.to(device)\n",
        "        y_hat = model(batched_g, batched_g.ndata[\"feat\"], batched_g.ndata[\"PE\"])\n",
        "        y_true.append(labels.view(y_hat.shape).detach().cpu())\n",
        "        y_pred.append(y_hat.detach().cpu())\n",
        "    y_true = torch.cat(y_true, dim=0).numpy()\n",
        "    y_pred = torch.cat(y_pred, dim=0).numpy()\n",
        "    input_dict = {\"y_true\": y_true, \"y_pred\": y_pred}\n",
        "    return evaluator.eval(input_dict)[\"rocauc\"]\n",
        "\n",
        "\n",
        "def train(model, dataset, evaluator, device):\n",
        "    train_dataloader = GraphDataLoader(\n",
        "        dataset[dataset.train_idx],\n",
        "        batch_size=256,\n",
        "        shuffle=True,\n",
        "        collate_fn=collate_dgl,\n",
        "    )\n",
        "    valid_dataloader = GraphDataLoader(\n",
        "        dataset[dataset.val_idx], batch_size=256, collate_fn=collate_dgl\n",
        "    )\n",
        "    test_dataloader = GraphDataLoader(\n",
        "        dataset[dataset.test_idx], batch_size=256, collate_fn=collate_dgl\n",
        "    )\n",
        "    optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
        "    num_epochs = 20\n",
        "    scheduler = optim.lr_scheduler.StepLR(\n",
        "        optimizer, step_size=num_epochs, gamma=0.5\n",
        "    )\n",
        "    loss_fcn = nn.BCEWithLogitsLoss()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        total_loss = 0.0\n",
        "        for batched_g, labels in train_dataloader:\n",
        "            batched_g, labels = batched_g.to(device), labels.to(device)\n",
        "            logits = model(\n",
        "                batched_g, batched_g.ndata[\"feat\"], batched_g.ndata[\"PE\"]\n",
        "            )\n",
        "            loss = loss_fcn(logits, labels.float())\n",
        "            total_loss += loss.item()\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "        scheduler.step()\n",
        "        avg_loss = total_loss / len(train_dataloader)\n",
        "        val_metric = evaluate(model, valid_dataloader, evaluator, device)\n",
        "        test_metric = evaluate(model, test_dataloader, evaluator, device)\n",
        "        print(\n",
        "            f\"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, \"\n",
        "            f\"Val: {val_metric:.4f}, Test: {test_metric:.4f}\"\n",
        "        )\n",
        "\n",
        "\n",
        "# Training device.\n",
        "dev = torch.device(\"cpu\")\n",
        "# Uncomment the code below to train on GPU. Be sure to install DGL with CUDA support.\n",
        "#dev = torch.device(\"cuda:0\")\n",
        "\n",
        "# Load dataset.\n",
        "pos_enc_size = 8\n",
        "dataset = AsGraphPredDataset(\n",
        "    DglGraphPropPredDataset(\"ogbg-molhiv\", \"./data/OGB\")\n",
        ")\n",
        "evaluator = Evaluator(\"ogbg-molhiv\")\n",
        "\n",
        "# Down sample the dataset to make the tutorial run faster.\n",
        "import random\n",
        "random.seed(42)\n",
        "train_size = len(dataset.train_idx)\n",
        "val_size = len(dataset.val_idx)\n",
        "test_size = len(dataset.test_idx)\n",
        "dataset.train_idx = dataset.train_idx[\n",
        "    torch.LongTensor(random.sample(range(train_size), 2000))\n",
        "]\n",
        "dataset.val_idx = dataset.val_idx[\n",
        "    torch.LongTensor(random.sample(range(val_size), 1000))\n",
        "]\n",
        "dataset.test_idx = dataset.test_idx[\n",
        "    torch.LongTensor(random.sample(range(test_size), 1000))\n",
        "]\n",
        "\n",
        "# Laplacian positional encoding.\n",
        "indices = torch.cat([dataset.train_idx, dataset.val_idx, dataset.test_idx])\n",
        "for idx in tqdm(indices, desc=\"Computing Laplacian PE\"):\n",
        "    g, _ = dataset[idx]\n",
        "    g.ndata[\"PE\"] = dgl.laplacian_pe(g, k=pos_enc_size, padding=True)\n",
        "\n",
        "# Create model.\n",
        "out_size = dataset.num_tasks\n",
        "model = GTModel(out_size=out_size, pos_enc_size=pos_enc_size).to(dev)\n",
        "\n",
        "# Kick off training.\n",
        "train(model, dataset, evaluator, dev)"
      ]
    }
  ],
  "metadata": {
    "language_info": {
      "name": "python"
    },
    "orig_nbformat": 4,
    "colab": {
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
